import torch


def add_gumbel(o_t, eps=1e-10, gpu=False):
    """Return ``o_t`` plus element-wise noise sampled from Gumbel(0, 1).

    Gumbel(0, 1) samples are produced by inverse-transform sampling,
    ``g = -log(-log(u))`` with ``u ~ Uniform[0, 1)``.

    Args:
        o_t: Input tensor. The noise has the same shape as ``o_t`` and is
            created on the same device as ``o_t``.
        eps: Small constant added inside both logarithms to guard them
            against a zero argument.
        gpu: Kept only for backward compatibility; it no longer affects
            placement. The noise always follows ``o_t.device``, which avoids
            device-mismatch errors when this flag disagreed with where
            ``o_t`` actually lives.

    Returns:
        A new tensor ``o_t + g``; ``o_t`` itself is not modified.
    """
    # Sample directly on o_t's device (no CPU allocation + copy). The noise
    # keeps the default floating dtype, as before, so the result's dtype is
    # still decided by type promotion in the final addition.
    u = torch.rand(o_t.size(), device=o_t.device)
    g_t = -torch.log(-torch.log(u + eps) + eps)
    return o_t + g_t
